import math
import time
from class_family import *

import warnings
warnings.filterwarnings('ignore')

def simulate_moments(param_choice, param_given, param_random, disp = False):
	'''
	Simulate households and return the moments for the 2->3 fit
	(income and preference heterogeneity).

	Parameters
	----------
	param_choice : sequence
		Estimated parameters, in order:
		[pi_n, pi_nq, theta, alpha, rho, sig_eps, eps_cutoff, cor_lny_eps].
	param_given : dict
		Fixed parameters with keys 'mu_lny', 'sig_lny', 'mu_eps', 'pi_s',
		'pi_q', 'tau', 'gamma'.
	param_random : dict
		'n_draws' : number of (log income, eps) draws;
		'seed'    : int or None, passed to np.random.default_rng.
	disp : bool, optional
		If True, print progress for every draw.

	Returns
	-------
	numpy.ndarray of length 8
		[beta_A_2to3, beta_B_2to3, PA_2to3, PB_2to3,
		 exp_n_AB, exp_nq_AB, cor_lny_n3, cor_lny_lnq_AB].
		If no draw passes the income check, every entry is NaN.
		If type A or B is absent from the filtered sample, its share is 0 and
		its beta is NaN (instead of raising KeyError).
		If the filtered sample is empty, both shares are NaN.
	'''

	# choice variables: 8 parameters
	pi_n = param_choice[0]
	pi_nq = param_choice[1]

	theta = param_choice[2]
	alpha = param_choice[3]
	rho = param_choice[4]

	sig_eps = param_choice[5]
	eps_cutoff = param_choice[6]

	cor_lny_eps = param_choice[7]

	# given variables: 7 parameters
	mu_lny = param_given['mu_lny']
	sig_lny = param_given['sig_lny']
	mu_eps = param_given['mu_eps']
	pi_s = param_given['pi_s']
	pi_q = param_given['pi_q']
	tau = param_given['tau']
	gamma = param_given['gamma']

	# randomization variables
	n_draws = param_random['n_draws']
	seed = param_random['seed'] # an integer or None

	# joint normal of (log income, eps) with correlation cor_lny_eps
	mu = [mu_lny, mu_eps]
	cov_lny_eps = cor_lny_eps * sig_lny * sig_eps
	cov = [[sig_lny**2, cov_lny_eps], [cov_lny_eps, sig_eps**2]]

	# a fixed seed reproduces the same draws across objective evaluations
	rng = np.random.default_rng(seed)
	lny_eps_draws = rng.multivariate_normal(mu, cov, n_draws)

	# container of simulation results
	simuData_list = []

	# one model instance; FR.y, FR.pi and FR.a are overwritten for every
	# draw below before the model is evaluated
	FR = FamilyRation_ump(
		pi_n = pi_n, pi_nq = pi_nq, pi_s = pi_s, pi_q = pi_q,
		theta = theta, rho = rho, gamma = gamma,
		alpha = alpha, y = 10
		)

	# main simulation, one household per draw
	if disp:
		print("Running a simulation ... " + str(n_draws) + " iterations")
	for i, (lny, eps) in enumerate(lny_eps_draws):

		if disp:
			print("---------------------")
			print("Iteration: ", i)
			print("[lny, eps]: ", lny, eps)
			print("---------------------")

		# income in levels, scaled by 1/1000
		y = np.exp(lny)/1000

		# child price includes a share tau of income
		pi_n_tau = pi_n + tau * y

		# skip households that cannot afford three children with a margin of 0.5
		if y - pi_n_tau * 3 < 0.5:
			if disp: print("Insufficient income to bear three children, pass")
			continue

		# shift preference parameter theta by eps; values below 0 become 0.01,
		# values above 1 become 0.99, values in [0, 1] are kept as is
		theta_new = theta + eps
		if theta_new < 0:
			theta_new = 0.01
		elif theta_new > 1:
			theta_new = 0.99

		# update FR with this household's income, prices and preferences
		FR.y = y
		FR.pi = np.array( [pi_n_tau, pi_q, pi_s, pi_nq] )
		FR.a = np.array( [alpha, rho, theta_new, gamma] )

		# evaluate optimal N, mother type, and QQ effects; evalDiscreteN0 is the
		# numeric solver (noted as ~8x faster than the symbolic evalDiscreteN)
		FR.evalDiscreteN0(disp = False)
		FR.evalQQ_DiscreteN()

		simuData = FR.getSimuData()
		simuData['lny'] = lny
		simuData['eps'] = eps
		simuData_list.append(simuData)

	# every draw failed the income check: pd.concat([]) would raise
	if not simuData_list:
		if disp: print("No draw passed the income check; returning NaN moments")
		return np.full(8, np.nan)

	# generate simulation dataframe
	simu_df = pd.concat(simuData_list).reset_index(drop=True)

	# drop households with n_opt <= 1
	simu_df = simu_df[simu_df['n_opt'] > 1]

	# type_2to3:
	#   A: n_opt == 2
	#   B: n_opt == 3 and eps < eps_cutoff (observed as a 2-child family)
	#   C: all remaining households (observed as 3+)
	simu_df = simu_df.assign(
		type_2to3 = lambda x: np.where(
			x['n_opt'] == 2, 'A',
			np.where((x['n_opt'] == 3) & (x['eps'] < eps_cutoff), 'B', 'C')
			)
		)

	# n_observed: 2 for types A and B, 3 (meaning 3+) for type C
	simu_df = simu_df.assign(
		n_observed = lambda x: np.where(x['type_2to3'].isin(['A', 'B']), 2, 3),
		n3 = lambda x: x['type_2to3'] == 'C' # 3+ indicator
		)

	# q_observed: only for types A and B, for whom we know the rationing pattern
	simu_df = simu_df.assign(
		q_observed = lambda x: np.where(x['type_2to3'].isin(['A', 'B']), x['q_n2'], np.nan),
		lnq_observed = lambda x: np.where(x['type_2to3'].isin(['A', 'B']), np.log(x['q_n2']), np.nan)
		)

	# expenditure shares for types A and B (NaN otherwise); the child share
	# uses the base price pi_n only, i.e. excludes the income-proportional tau * y
	simu_df = simu_df.assign(
		exp_n_AB = lambda x: np.where(
			x['type_2to3'].isin(['A', 'B']),
			pi_n * x['n_observed'] / x['y'],
			np.nan
			),
		exp_nq_AB = lambda x: np.where(
			x['type_2to3'].isin(['A', 'B']),
			pi_nq * x['n_observed'] * x['q_observed'] / x['y'],
			np.nan
			)
		)

	# drop rows with q_observed == 0; type C rows carry q_observed = NaN and
	# NaN != 0 is True, so they are kept
	simu_df = simu_df[simu_df['q_observed'] != 0]

	# --- moments ---

	# causal effects averaged within types A and B (NaN if the type is absent)
	beta_A_2to3 = simu_df[simu_df['type_2to3'] == 'A']["qq_e_DN_2to3"].mean()
	beta_B_2to3 = simu_df[simu_df['type_2to3'] == 'B']["qq_e_DN_2to3"].mean()

	# type shares; .get() yields 0 for a type that does not occur instead of
	# raising KeyError, and an empty sample yields NaN instead of dividing by 0
	type_2to3_count = simu_df['type_2to3'].value_counts()
	obs_2to3 = len(simu_df)
	if obs_2to3 > 0:
		PA_2to3 = type_2to3_count.get('A', 0) / obs_2to3
		PB_2to3 = type_2to3_count.get('B', 0) / obs_2to3  # PB represents PB_x in the Appendix
	else:
		PA_2to3 = PB_2to3 = np.nan

	# income ~ fertility correlation (log income vs. 3+ indicator)
	cor_lny_n3 = simu_df['lny'].corr(simu_df['n3'])

	# descriptives for types A and B
	simu_df_AB = simu_df[simu_df['type_2to3'].isin(['A', 'B'])]
	exp_n_AB = simu_df_AB['exp_n_AB'].mean()
	exp_nq_AB = simu_df_AB['exp_nq_AB'].mean()
	cor_lny_lnq_AB = simu_df_AB['lny'].corr(simu_df_AB['lnq_observed'])

	# collect moments
	return np.array(
		[beta_A_2to3, beta_B_2to3, PA_2to3, PB_2to3,
		exp_n_AB, exp_nq_AB, cor_lny_n3, cor_lny_lnq_AB]
		)






def simulate_objective(simulate_moments, param_choice, param_given, param_random, 
	moments_real, W, disp = False):
	'''
	Evaluate the simulated-method-of-moments objective at param_choice.

	The moment simulator is injected as the callable `simulate_moments` and is
	invoked as simulate_moments(param_choice, param_given, param_random, False).
	With M = simulated moments - moments_real, the returned value is the
	quadratic form M W M' (summed to a scalar).

	When disp is truthy, the parameters, both moment vectors, their difference,
	the objective value and the elapsed wall-clock time are printed.
	'''
	banner = "-------------------------------------------------------"

	if disp:
		t0 = time.time()
		for header in (
			banner,
			"Evaluating the objective function via a simulation ... ",
			"Parameter names: pi_n pi_nq theta alpha rho sig_eps eps_cutoff cor_lny_eps",
		):
			print(header)
		print("param_choice:", param_choice)

	simulated = simulate_moments(param_choice, param_given, param_random, False)

	# deviation between simulated and observed moments
	deviation = simulated - moments_real

	# quadratic form: deviation . (W . deviation')
	weighted = np.dot(W, deviation.T)
	value = np.sum(np.dot(deviation, weighted))

	if disp:
		report = (
			("Real moments:", moments_real),
			("Simu moments:", simulated),
			("Distance    :", deviation),
			("Objective function value: ", value),
		)
		for label, item in report:
			print(label, item)
		elapsed = time.time() - t0
		print(f"Simulation time: {elapsed} seconds")
		print(banner)

	return value







########## for the calculation of SEs


def simulate_moments_se(param_choice, param_given, lny_eps_draws, disp = False):
	'''
	Simulate households and return the moments for the 2->3 fit
	(income and preference heterogeneity), using pre-drawn (lny, eps) pairs.
	Intended for the calculation of standard errors.

	Parameters
	----------
	param_choice : sequence
		Estimated parameters, in order:
		[pi_n, pi_nq, theta, alpha, rho, sig_eps, eps_cutoff, cor_lny_eps].
		sig_eps and cor_lny_eps are not used here, because the draws are
		supplied by the caller.
		NOTE(review): if this function is differentiated numerically for
		standard errors, the derivatives w.r.t. sig_eps and cor_lny_eps are
		therefore zero. Confirm this is intended.
	param_given : dict
		Fixed parameters; keys 'pi_s', 'pi_q', 'tau', 'gamma' are used.
	lny_eps_draws : sequence of (lny, eps) pairs
		For example, an (n, 2) array like the one drawn in simulate_moments.
	disp : bool, optional
		If True, print progress for every draw.

	Returns
	-------
	numpy.ndarray of length 8
		[beta_A_2to3, beta_B_2to3, PA_2to3, PB_2to3,
		 exp_n_AB, exp_nq_AB, cor_lny_n3, cor_lny_lnq_AB].
		If no draw passes the income check, every entry is NaN.
		If type A or B is absent from the filtered sample, its share is 0 and
		its beta is NaN (instead of raising KeyError).
		If the filtered sample is empty, both shares are NaN.
	'''

	# choice variables used by the model (sig_eps, cor_lny_eps only shape draws)
	pi_n = param_choice[0]
	pi_nq = param_choice[1]

	theta = param_choice[2]
	alpha = param_choice[3]
	rho = param_choice[4]

	eps_cutoff = param_choice[6]

	# given variables used by the model
	pi_s = param_given['pi_s']
	pi_q = param_given['pi_q']
	tau = param_given['tau']
	gamma = param_given['gamma']

	# container of simulation results
	simuData_list = []

	# one model instance; FR.y, FR.pi and FR.a are overwritten for every
	# draw below before the model is evaluated
	FR = FamilyRation_ump(
		pi_n = pi_n, pi_nq = pi_nq, pi_s = pi_s, pi_q = pi_q,
		theta = theta, rho = rho, gamma = gamma,
		alpha = alpha, y = 10
		)

	# n_draws must be known before the progress message below (previously it
	# was assigned after the print, raising UnboundLocalError when disp=True)
	n_draws = len(lny_eps_draws)

	# main simulation, one household per draw
	if disp:
		print("Running a simulation ... " + str(n_draws) + " iterations")

	for i, (lny, eps) in enumerate(lny_eps_draws):

		if disp:
			print("---------------------")
			print("Iteration: ", i)
			print("[lny, eps]: ", lny, eps)
			print("---------------------")

		# income in levels, scaled by 1/1000
		y = np.exp(lny)/1000

		# child price includes a share tau of income
		pi_n_tau = pi_n + tau * y

		# skip households that cannot afford three children with a margin of 0.5
		if y - pi_n_tau * 3 < 0.5:
			if disp: print("Insufficient income to bear three children, pass")
			continue

		# shift preference parameter theta by eps; values below 0 become 0.01,
		# values above 1 become 0.99, values in [0, 1] are kept as is
		theta_new = theta + eps
		if theta_new < 0:
			theta_new = 0.01
		elif theta_new > 1:
			theta_new = 0.99

		# update FR with this household's income, prices and preferences
		FR.y = y
		FR.pi = np.array( [pi_n_tau, pi_q, pi_s, pi_nq] )
		FR.a = np.array( [alpha, rho, theta_new, gamma] )

		# evaluate optimal N, mother type, and QQ effects; evalDiscreteN0 is the
		# numeric solver (noted as ~8x faster than the symbolic evalDiscreteN)
		FR.evalDiscreteN0(disp = False)
		FR.evalQQ_DiscreteN()

		simuData = FR.getSimuData()
		simuData['lny'] = lny
		simuData['eps'] = eps
		simuData_list.append(simuData)

	# every draw failed the income check (or no draws): pd.concat([]) would raise
	if not simuData_list:
		if disp: print("No draw passed the income check; returning NaN moments")
		return np.full(8, np.nan)

	# generate simulation dataframe
	simu_df = pd.concat(simuData_list).reset_index(drop=True)

	# drop households with n_opt <= 1
	simu_df = simu_df[simu_df['n_opt'] > 1]

	# type_2to3:
	#   A: n_opt == 2
	#   B: n_opt == 3 and eps < eps_cutoff (observed as a 2-child family)
	#   C: all remaining households (observed as 3+)
	simu_df = simu_df.assign(
		type_2to3 = lambda x: np.where(
			x['n_opt'] == 2, 'A',
			np.where((x['n_opt'] == 3) & (x['eps'] < eps_cutoff), 'B', 'C')
			)
		)

	# n_observed: 2 for types A and B, 3 (meaning 3+) for type C
	simu_df = simu_df.assign(
		n_observed = lambda x: np.where(x['type_2to3'].isin(['A', 'B']), 2, 3),
		n3 = lambda x: x['type_2to3'] == 'C' # 3+ indicator
		)

	# q_observed: only for types A and B, for whom we know the rationing pattern
	simu_df = simu_df.assign(
		q_observed = lambda x: np.where(x['type_2to3'].isin(['A', 'B']), x['q_n2'], np.nan),
		lnq_observed = lambda x: np.where(x['type_2to3'].isin(['A', 'B']), np.log(x['q_n2']), np.nan)
		)

	# expenditure shares for types A and B (NaN otherwise); the child share
	# uses the base price pi_n only, i.e. excludes the income-proportional tau * y
	simu_df = simu_df.assign(
		exp_n_AB = lambda x: np.where(
			x['type_2to3'].isin(['A', 'B']),
			pi_n * x['n_observed'] / x['y'],
			np.nan
			),
		exp_nq_AB = lambda x: np.where(
			x['type_2to3'].isin(['A', 'B']),
			pi_nq * x['n_observed'] * x['q_observed'] / x['y'],
			np.nan
			)
		)

	# drop rows with q_observed == 0; type C rows carry q_observed = NaN and
	# NaN != 0 is True, so they are kept
	simu_df = simu_df[simu_df['q_observed'] != 0]

	# --- moments ---

	# causal effects averaged within types A and B (NaN if the type is absent)
	beta_A_2to3 = simu_df[simu_df['type_2to3'] == 'A']["qq_e_DN_2to3"].mean()
	beta_B_2to3 = simu_df[simu_df['type_2to3'] == 'B']["qq_e_DN_2to3"].mean()

	# type shares; .get() yields 0 for a type that does not occur instead of
	# raising KeyError, and an empty sample yields NaN instead of dividing by 0
	type_2to3_count = simu_df['type_2to3'].value_counts()
	obs_2to3 = len(simu_df)
	if obs_2to3 > 0:
		PA_2to3 = type_2to3_count.get('A', 0) / obs_2to3
		PB_2to3 = type_2to3_count.get('B', 0) / obs_2to3  # PB represents PB_x in the Appendix
	else:
		PA_2to3 = PB_2to3 = np.nan

	# income ~ fertility correlation (log income vs. 3+ indicator)
	cor_lny_n3 = simu_df['lny'].corr(simu_df['n3'])

	# descriptives for types A and B
	simu_df_AB = simu_df[simu_df['type_2to3'].isin(['A', 'B'])]
	exp_n_AB = simu_df_AB['exp_n_AB'].mean()
	exp_nq_AB = simu_df_AB['exp_nq_AB'].mean()
	cor_lny_lnq_AB = simu_df_AB['lny'].corr(simu_df_AB['lnq_observed'])

	# collect moments
	return np.array(
		[beta_A_2to3, beta_B_2to3, PA_2to3, PB_2to3,
		exp_n_AB, exp_nq_AB, cor_lny_n3, cor_lny_lnq_AB]
		)


def simulate_objective_se(simulate_moments, param_choice, param_given, lny_eps_draws, 
	moments_real, W, disp = False):
	'''
	Evaluate the simulated-method-of-moments objective at param_choice, using
	caller-supplied (lny, eps) draws.

	The moment simulator is injected as the callable `simulate_moments` and is
	invoked as simulate_moments(param_choice, param_given, lny_eps_draws, False).
	With M = simulated moments - moments_real, the returned value is the
	quadratic form M W M' (summed to a scalar).

	When disp is truthy, the parameters, both moment vectors, their difference,
	the objective value and the elapsed wall-clock time are printed.
	'''
	banner = "-------------------------------------------------------"

	if disp:
		t0 = time.time()
		for header in (
			banner,
			"Evaluating the objective function via a simulation ... ",
			"Parameter names: pi_n pi_nq theta alpha rho sig_eps eps_cutoff cor_lny_eps",
		):
			print(header)
		print("param_choice:", param_choice)

	simulated = simulate_moments(param_choice, param_given, lny_eps_draws, False)

	# deviation between simulated and observed moments
	deviation = simulated - moments_real

	# quadratic form: deviation . (W . deviation')
	weighted = np.dot(W, deviation.T)
	value = np.sum(np.dot(deviation, weighted))

	if disp:
		report = (
			("Real moments:", moments_real),
			("Simu moments:", simulated),
			("Distance    :", deviation),
			("Objective function value: ", value),
		)
		for label, item in report:
			print(label, item)
		elapsed = time.time() - t0
		print(f"Simulation time: {elapsed} seconds")
		print(banner)

	return value


